{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8124bfdf-bd97-4814-8c80-560f4e0d2334",
   "metadata": {},
   "source": [
    "# Lesson 4 - Hybrid Search"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41241c42",
   "metadata": {},
   "source": [
    "### Import the Needed Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de3c06eb-524a-4b35-9db2-f440392eccfc",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80a61ece-03da-4e87-9d0a-d99e494bbb71",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from pinecone_text.sparse import BM25Encoder\n",
    "from pinecone import Pinecone, ServerlessSpec\n",
    "from DLAIUtils import Utils\n",
    "\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea1366a-9ff9-4471-8eea-c88227b22789",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "utils = Utils()\n",
    "PINECONE_API_KEY = utils.get_pinecone_api_key()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52794a31",
   "metadata": {},
   "source": [
    "### Setup Pinecone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ff87e06-9822-4528-82af-0c73ae11059d",
   "metadata": {
    "height": 318
   },
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(device)\n",
    "\n",
    "utils = Utils()\n",
    "INDEX_NAME = utils.create_dlai_index_name('dl-ai')\n",
    "\n",
    "pinecone = Pinecone(api_key=PINECONE_API_KEY)\n",
    "\n",
    "if INDEX_NAME in [index.name for index in pinecone.list_indexes()]:\n",
    "  pinecone.delete_index(INDEX_NAME)\n",
    "pinecone.create_index(\n",
    "  INDEX_NAME,\n",
    "  dimension=512,\n",
    "  metric=\"dotproduct\",\n",
    "  spec=ServerlessSpec(cloud='aws', region='us-west-2')\n",
    ")\n",
    "index = pinecone.Index(INDEX_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "257830a5",
   "metadata": {},
   "source": [
    "### Load the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb9ea178-7737-4916-8eaa-24a824b41cb0",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "fashion = load_dataset(\n",
    "    \"ashraq/fashion-product-images-small\",\n",
    "    split=\"train\"\n",
    ")\n",
    "fashion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "909a738b-1c70-4b0d-8566-097d514b404d",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "images = fashion['image']\n",
    "metadata = fashion.remove_columns('image')\n",
    "images[900]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15faca59-729e-4cd4-9fc3-00feea488c9c",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "metadata = metadata.to_pandas()\n",
    "metadata.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82acbbac",
   "metadata": {},
   "source": [
    "### Create the Sparse Vector Using BM25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "798ac46a-acb8-4eda-84a2-74f96e68d960",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "bm25 = BM25Encoder()\n",
    "bm25.fit(metadata['productDisplayName'])\n",
    "metadata['productDisplayName'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb47ff8e-51d7-4dbb-b346-c1bde866e130",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "bm25.encode_queries(metadata['productDisplayName'][0])\n",
    "bm25.encode_documents(metadata['productDisplayName'][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db48a9b1",
   "metadata": {},
   "source": [
    "### Create the Dense Vector Using CLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91284667-e998-4c17-bcd9-d70c069e9bb2",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', \n",
    "    device=device)\n",
    "model\n",
    "dense_vec = model.encode([metadata['productDisplayName'][0]])\n",
    "dense_vec.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5b2755e-57b5-44d5-82de-3245f832182e",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "len(fashion)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84e2c8f5",
   "metadata": {},
   "source": [
    "### Create Embeddings Using Sparse and Dense"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4eb8ec3-ec58-4a5b-ace2-a48e8b671f58",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff1d7; padding:15px; \"> <b>(Note: <code>fashion_data_num = 1000</code>):</b> In this lab, we've initially set <code>fashion_data_num</code> to 1000 for speedier results, allowing you to observe the outcomes faster. Once you've done an initial run, consider increasing this value. You'll likely notice better and more relevant results.</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6852c5-0e99-450a-b6fe-5def9da89781",
   "metadata": {
    "height": 607
   },
   "outputs": [],
   "source": [
    "batch_size = 100\n",
    "fashion_data_num = 1000\n",
    "\n",
    "for i in tqdm(range(0, min(fashion_data_num,len(fashion)), batch_size)):\n",
    "    # find end of batch\n",
    "    i_end = min(i+batch_size, len(fashion))\n",
    "    # extract metadata batch\n",
    "    meta_batch = metadata.iloc[i:i_end]\n",
    "    meta_dict = meta_batch.to_dict(orient=\"records\")\n",
    "    # concatinate all metadata field except for id and year to form a single string\n",
    "    meta_batch = [\" \".join(x) for x in meta_batch.loc[:, ~meta_batch.columns.isin(['id', 'year'])].values.tolist()]\n",
    "    # extract image batch\n",
    "    img_batch = images[i:i_end]\n",
    "    # create sparse BM25 vectors\n",
    "    sparse_embeds = bm25.encode_documents([text for text in meta_batch])\n",
    "    # create dense vectors\n",
    "    dense_embeds = model.encode(img_batch).tolist()\n",
    "    # create unique IDs\n",
    "    ids = [str(x) for x in range(i, i_end)]\n",
    "\n",
    "    upserts = []\n",
    "    # loop through the data and create dictionaries for uploading documents to pinecone index\n",
    "    for _id, sparse, dense, meta in zip(ids, sparse_embeds, dense_embeds, meta_dict):\n",
    "        upserts.append({\n",
    "            'id': _id,\n",
    "            'sparse_values': sparse,\n",
    "            'values': dense,\n",
    "            'metadata': meta\n",
    "        })\n",
    "    # upload the documents to the new hybrid index\n",
    "    index.upsert(upserts)\n",
    "\n",
    "# show index description after uploading the documents\n",
    "index.describe_index_stats()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7487f3c",
   "metadata": {},
   "source": [
    "### Run Your Query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "942d6de8-b967-4168-8406-b15c354bd58c",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "query = \"dark blue french connection jeans for men\"\n",
    "\n",
    "sparse = bm25.encode_queries(query)\n",
    "dense = model.encode(query).tolist()\n",
    "\n",
    "result = index.query(\n",
    "    top_k=14,\n",
    "    vector=dense,\n",
    "    sparse_vector=sparse,\n",
    "    include_metadata=True\n",
    ")\n",
    "\n",
    "imgs = [images[int(r[\"id\"])] for r in result[\"matches\"]]\n",
    "imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6be9256b-889d-4803-9f0d-206bbba08ab7",
   "metadata": {
    "height": 369
   },
   "outputs": [],
   "source": [
    "from IPython.core.display import HTML\n",
    "from io import BytesIO\n",
    "from base64 import b64encode\n",
    "\n",
    "# function to display product images\n",
    "def display_result(image_batch):\n",
    "    figures = []\n",
    "    for img in image_batch:\n",
    "        b = BytesIO()\n",
    "        img.save(b, format='png')\n",
    "        figures.append(f'''\n",
    "            <figure style=\"margin: 5px !important;\">\n",
    "              <img src=\"data:image/png;base64,{b64encode(b.getvalue()).decode('utf-8')}\" style=\"width: 90px; height: 120px\" >\n",
    "            </figure>\n",
    "        ''')\n",
    "    return HTML(data=f'''\n",
    "        <div style=\"display: flex; flex-flow: row wrap; text-align: center;\">\n",
    "        {''.join(figures)}\n",
    "        </div>\n",
    "    ''')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03260484-92f7-4394-9217-dc848d21e64f",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "display_result(imgs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "898e32d0",
   "metadata": {},
   "source": [
    "### Scaling the Hybrid Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc3ce832-f7c5-4592-ac91-18bc222f127c",
   "metadata": {
    "height": 369
   },
   "outputs": [],
   "source": [
    "def hybrid_scale(dense, sparse, alpha: float):\n",
    "    \"\"\"Hybrid vector scaling using a convex combination\n",
    "\n",
    "    alpha * dense + (1 - alpha) * sparse\n",
    "\n",
    "    Args:\n",
    "        dense: Array of floats representing\n",
    "        sparse: a dict of `indices` and `values`\n",
    "        alpha: float between 0 and 1 where 0 == sparse only\n",
    "               and 1 == dense only\n",
    "    \"\"\"\n",
    "    if alpha < 0 or alpha > 1:\n",
    "        raise ValueError(\"Alpha must be between 0 and 1\")\n",
    "    # scale sparse and dense vectors to create hybrid search vecs\n",
    "    hsparse = {\n",
    "        'indices': sparse['indices'],\n",
    "        'values':  [v * (1 - alpha) for v in sparse['values']]\n",
    "    }\n",
    "    hdense = [v * alpha for v in dense]\n",
    "    return hdense, hsparse"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a22e97f",
   "metadata": {},
   "source": [
    "### 1. More Dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "551b4bb5-3f5c-4130-91c1-936442358939",
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "question = \"dark blue french connection jeans for men\"\n",
    "#Closer to 0==more sparse, closer to 1==more dense\n",
    "hdense, hsparse = hybrid_scale(dense, sparse, alpha=1)\n",
    "result = index.query(\n",
    "    top_k=6,\n",
    "    vector=hdense,\n",
    "    sparse_vector=hsparse,\n",
    "    include_metadata=True\n",
    ")\n",
    "imgs = [images[int(r[\"id\"])] for r in result[\"matches\"]]\n",
    "display_result(imgs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe3a710e-19b0-4f0f-83db-eaaba499b332",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "for x in result[\"matches\"]:\n",
    "    print(x[\"metadata\"]['productDisplayName'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5109bec",
   "metadata": {},
   "source": [
    "### 2. More Sparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46cc0b53-c422-4ae4-b417-66862c1f757a",
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "question = \"dark blue french connection jeans for men\"\n",
    "#Closer to 0==more sparse, closer to 1==more dense\n",
    "hdense, hsparse = hybrid_scale(dense, sparse, alpha=0)\n",
    "result = index.query(\n",
    "    top_k=6,\n",
    "    vector=hdense,\n",
    "    sparse_vector=hsparse,\n",
    "    include_metadata=True\n",
    ")\n",
    "imgs = [images[int(r[\"id\"])] for r in result[\"matches\"]]\n",
    "display_result(imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f9c22fc-426e-46f0-ae6c-e2c9c2a66ea9",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "for x in result[\"matches\"]:\n",
    "    print(x[\"metadata\"]['productDisplayName'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac079a01",
   "metadata": {},
   "source": [
    "### More Dense or More Sparse?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cc6b0d9-646b-4f56-b6a7-a9a232ae5429",
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "question = \"dark blue french connection jeans for men\"\n",
    "#Closer to 0==more sparse, closer to 1==more dense\n",
    "hdense, hsparse = hybrid_scale(dense, sparse, alpha=1)\n",
    "result = index.query(\n",
    "    top_k=6,\n",
    "    vector=hdense,\n",
    "    sparse_vector=hsparse,\n",
    "    include_metadata=True\n",
    ")\n",
    "imgs = [images[int(r[\"id\"])] for r in result[\"matches\"]]\n",
    "display_result(imgs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a51bcf3-bb5a-447d-96cf-789bb798e665",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "for x in result[\"matches\"]:\n",
    "    print(x[\"metadata\"]['productDisplayName'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
